{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-07-19T01:52:26.653722Z",
     "iopub.status.busy": "2021-07-19T01:52:26.653529Z",
     "iopub.status.idle": "2021-07-19T01:52:50.107349Z",
     "shell.execute_reply": "2021-07-19T01:52:50.106776Z",
     "shell.execute_reply.started": "2021-07-19T01:52:26.653700Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/lyz/.local/lib/python3.6/site-packages/requests/__init__.py:91: RequestsDependencyWarning: urllib3 (1.26.6) or chardet (3.0.4) doesn't match a supported version!\n",
      "  RequestsDependencyWarning)\n"
     ]
    }
   ],
   "source": [
    "# -*- coding: utf-8 -*-\n",
    "import os, sys, glob, argparse\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "\n",
    "import time, datetime\n",
    "import pdb, traceback\n",
    "\n",
    "import cv2\n",
    "from PIL import Image\n",
    "from sklearn.model_selection import train_test_split, StratifiedKFold, KFold\n",
    "\n",
    "import torch\n",
    "import torchvision.models as models\n",
    "import torchvision.transforms as transforms\n",
    "import torchvision.datasets as datasets\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from torch.autograd import Variable\n",
    "from torch.utils.data.dataset import Dataset\n",
    "import timm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-07-19T01:52:50.108210Z",
     "iopub.status.busy": "2021-07-19T01:52:50.108011Z",
     "iopub.status.idle": "2021-07-19T01:52:53.391951Z",
     "shell.execute_reply": "2021-07-19T01:52:53.390201Z",
     "shell.execute_reply.started": "2021-07-19T01:52:50.108192Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "from joblib import dump, load\n",
    "\n",
    "# Sort before shuffling: glob returns paths in filesystem order, so shuffling the\n",
    "# raw result (even with a seed) gives a different permutation per machine and the\n",
    "# KFold split built from this array would not be reproducible.\n",
    "train_jpg = sorted(glob.glob('./Datawhale_人脸情绪识别_数据集/train/*/*'))\n",
    "np.random.RandomState(233).shuffle(train_jpg)\n",
    "train_jpg = np.array(train_jpg)\n",
    "\n",
    "# Cache of decoded RGB images keyed by file path, persisted with joblib so later\n",
    "# runs skip the decoding loop.\n",
    "# NOTE(review): joblib.load unpickles the file, which can execute arbitrary code;\n",
    "# only load a train_jpg.pkl that this notebook produced.\n",
    "if os.path.exists('train_jpg.pkl'):\n",
    "    train_jpg_dict = load('train_jpg.pkl')\n",
    "else:\n",
    "    train_jpg_dict = {}\n",
    "    for path in tqdm(train_jpg):\n",
    "        # TODO: add a resize here\n",
    "        with Image.open(path) as img:\n",
    "            train_jpg_dict[path] = img.convert('RGB')\n",
    "\n",
    "    dump(train_jpg_dict, 'train_jpg.pkl')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-07-19T01:52:53.395262Z",
     "iopub.status.busy": "2021-07-19T01:52:53.395044Z",
     "iopub.status.idle": "2021-07-19T01:52:53.400567Z",
     "shell.execute_reply": "2021-07-19T01:52:53.400165Z",
     "shell.execute_reply.started": "2021-07-19T01:52:53.395244Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "class QRDataset(Dataset):\n",
    "    \"\"\"Image dataset whose label comes from each file's parent directory name.\n",
    "\n",
    "    Args:\n",
    "        img_path: indexable collection of image file paths.\n",
    "        transform: optional callable applied to the RGB image.\n",
    "        cache: optional mapping from path to an already decoded image. When\n",
    "            omitted, the module-level ``train_jpg_dict`` is used if it exists;\n",
    "            paths missing from the cache are read from disk.\n",
    "    \"\"\"\n",
    "\n",
    "    # Emotion directory name -> integer class id.\n",
    "    LABELS = {'angry': 0,\n",
    "              'disgusted': 1,\n",
    "              'fearful': 2,\n",
    "              'happy': 3,\n",
    "              'neutral': 4,\n",
    "              'sad': 5,\n",
    "              'surprised': 6}\n",
    "\n",
    "    def __init__(self, img_path, transform=None, cache=None):\n",
    "        self.img_path = img_path\n",
    "        self.transform = transform\n",
    "        if cache is None:\n",
    "            # Backward compatible: fall back to the notebook-level image cache.\n",
    "            cache = globals().get('train_jpg_dict', {})\n",
    "        self.cache = cache\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        \"\"\"Return ``(image, label)``; the label is a 0-dim int64 tensor.\n",
    "\n",
    "        Raises:\n",
    "            KeyError: if the parent directory is not a known class and the path\n",
    "                does not contain 'test'.\n",
    "        \"\"\"\n",
    "        path = self.img_path[index]\n",
    "\n",
    "        if path in self.cache:\n",
    "            img = self.cache[path]\n",
    "        else:\n",
    "            img = Image.open(path).convert('RGB')\n",
    "        if self.transform is not None:\n",
    "            img = self.transform(img)\n",
    "\n",
    "        # Look at the class directory first instead of searching the whole path\n",
    "        # for 'test': a labelled image stored under e.g. 'contest/train/happy/'\n",
    "        # must keep its real label.\n",
    "        parent = os.path.basename(os.path.dirname(path))\n",
    "        if parent not in self.LABELS and 'test' in path:\n",
    "            label = 0  # placeholder for unlabeled test images\n",
    "        else:\n",
    "            label = self.LABELS[parent]\n",
    "        return img, torch.tensor(label, dtype=torch.long)\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.img_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-07-19T01:52:53.401351Z",
     "iopub.status.busy": "2021-07-19T01:52:53.401187Z",
     "iopub.status.idle": "2021-07-19T01:52:53.472522Z",
     "shell.execute_reply": "2021-07-19T01:52:53.471392Z",
     "shell.execute_reply.started": "2021-07-19T01:52:53.401335Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "class XunFeiNet(nn.Module):\n",
    "    \"\"\"timm backbone with a freshly initialised linear classification head.\n",
    "\n",
    "    Args:\n",
    "        pretrained: load pretrained backbone weights (default True). Pass False\n",
    "            when the weights are restored from a checkpoint right afterwards.\n",
    "        num_classes: number of output logits (default 7).\n",
    "        backbone: timm model name; the model must expose a ``classifier``\n",
    "            linear layer (default 'efficientnet_b1').\n",
    "    \"\"\"\n",
    "    def __init__(self, pretrained=True, num_classes=7, backbone='efficientnet_b1'):\n",
    "        super(XunFeiNet, self).__init__()\n",
    "\n",
    "        # timm's ``pretrained`` is a boolean flag; the former 'imagenet' string only\n",
    "        # worked because a non-empty string is truthy.\n",
    "        model = timm.create_model(backbone, pretrained=bool(pretrained))\n",
    "        # Read the feature width from the backbone instead of hard-coding 1280.\n",
    "        model.classifier = nn.Linear(model.classifier.in_features, num_classes)\n",
    "        self.model = model\n",
    "\n",
    "    def forward(self, img):\n",
    "        \"\"\"Return the backbone output (class logits) for a batch of images.\"\"\"\n",
    "        return self.model(img)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-07-19T01:52:53.474738Z",
     "iopub.status.busy": "2021-07-19T01:52:53.474080Z",
     "iopub.status.idle": "2021-07-19T01:52:53.551205Z",
     "shell.execute_reply": "2021-07-19T01:52:53.550778Z",
     "shell.execute_reply.started": "2021-07-19T01:52:53.474690Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "def validate(val_loader, model, criterion):\n",
    "    \"\"\"Evaluate ``model`` on ``val_loader`` and return its top-1 accuracy.\n",
    "\n",
    "    Accuracy is correct / total over all samples, so a smaller final batch is\n",
    "    weighted by its size instead of counting as much as a full batch.\n",
    "\n",
    "    Args:\n",
    "        val_loader: iterable of ``(images, targets)`` batches.\n",
    "        model: network to evaluate; batches are moved to the device of its\n",
    "            first parameter (left where they are if it has no parameters).\n",
    "        criterion: accepted for interface compatibility; not used.\n",
    "\n",
    "    Returns:\n",
    "        Top-1 accuracy as a float, ``nan`` when the loader yields no samples.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    first_param = next(model.parameters(), None)\n",
    "    device = first_param.device if first_param is not None else None\n",
    "\n",
    "    correct, total = 0, 0\n",
    "    with torch.no_grad():\n",
    "        for images, target in val_loader:\n",
    "            if device is not None:\n",
    "                images = images.to(device)\n",
    "                target = target.to(device)\n",
    "\n",
    "            output = model(images)\n",
    "            correct += (output.argmax(1) == target).sum().item()\n",
    "            total += target.size(0)\n",
    "\n",
    "    acc = correct / total if total else float('nan')\n",
    "    print(' * Val Acc@1 {0}'.format(acc))\n",
    "    return acc\n",
    "\n",
    "def predict(test_loader, model, tta=10):\n",
    "    \"\"\"Run ``tta`` inference passes over ``test_loader`` and sum the outputs.\n",
    "\n",
    "    The loader is re-iterated on every pass, so any random transforms in its\n",
    "    dataset are re-drawn each time. The passes are added element-wise (a sum,\n",
    "    not a mean) of the raw model outputs.\n",
    "\n",
    "    Args:\n",
    "        test_loader: iterable of ``(images, targets)`` batches; targets are ignored.\n",
    "        model: network to run; batches are moved to the device of its first\n",
    "            parameter (left where they are if it has no parameters).\n",
    "        tta: number of passes over the loader, at least 1.\n",
    "\n",
    "    Returns:\n",
    "        ``numpy.ndarray`` with the per-batch outputs stacked along the first axis.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if ``tta`` is less than 1 or the loader yields no batches.\n",
    "    \"\"\"\n",
    "    if tta < 1:\n",
    "        raise ValueError('tta must be >= 1, got {0}'.format(tta))\n",
    "\n",
    "    model.eval()\n",
    "    first_param = next(model.parameters(), None)\n",
    "    device = first_param.device if first_param is not None else None\n",
    "\n",
    "    test_pred_tta = None\n",
    "    with torch.no_grad():\n",
    "        for _ in range(tta):\n",
    "            batches = []\n",
    "            for images, _target in test_loader:\n",
    "                if device is not None:\n",
    "                    images = images.to(device)\n",
    "                batches.append(model(images).detach().cpu().numpy())\n",
    "            if not batches:\n",
    "                raise ValueError('test_loader yielded no batches')\n",
    "            test_pred = np.vstack(batches)\n",
    "\n",
    "            if test_pred_tta is None:\n",
    "                test_pred_tta = test_pred\n",
    "            else:\n",
    "                test_pred_tta += test_pred\n",
    "\n",
    "    return test_pred_tta\n",
    "\n",
    "def train(train_loader, model, criterion, optimizer, epoch):\n",
    "    \"\"\"Train ``model`` for one pass over ``train_loader``.\n",
    "\n",
    "    Every 100 batches the running mean of the per-batch top-1 accuracy is printed.\n",
    "\n",
    "    Args:\n",
    "        train_loader: iterable of ``(images, targets)`` batches.\n",
    "        model: network to optimise; batches are moved to the device of its\n",
    "            first parameter (left where they are if it has no parameters).\n",
    "        criterion: loss called as ``criterion(output, target)``.\n",
    "        optimizer: optimizer stepped once per batch.\n",
    "        epoch: accepted for interface compatibility; not used.\n",
    "    \"\"\"\n",
    "    model.train()\n",
    "    first_param = next(model.parameters(), None)\n",
    "    device = first_param.device if first_param is not None else None\n",
    "\n",
    "    batch_acc = []\n",
    "    for i, (images, target) in enumerate(train_loader):\n",
    "        if device is not None:\n",
    "            images = images.to(device, non_blocking=True)\n",
    "            target = target.to(device, non_blocking=True)\n",
    "        output = model(images)\n",
    "        loss = criterion(output, target)\n",
    "\n",
    "        batch_acc.append((output.argmax(1) == target).float().mean().item())\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        if i % 100 == 0:\n",
    "            print('Train: {0}'.format(np.mean(batch_acc)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-07-19T02:17:41.663000Z",
     "iopub.status.busy": "2021-07-19T02:17:41.662403Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Epoch:  0\n",
      "Train: 0.20000000298023224\n",
      "Train: 0.26435644050488377\n",
      "Train: 0.29129353697546084\n",
      "Train: 0.3127907025333853\n",
      "Train: 0.3283042449439107\n",
      "Train: 0.3439121814054286\n",
      "Train: 0.35582363327187233\n",
      "Train: 0.3669044285661909\n",
      "Train: 0.3784020042997993\n",
      "Train: 0.38662597811165844\n",
      "Train: 0.394255751611797\n",
      "Train: 0.40113533923333067\n",
      "Train: 0.4089925141600298\n",
      " * Val Acc@1 0.5065972295124084\n",
      "\n",
      "Epoch:  1\n",
      "Train: 0.45000001788139343\n"
     ]
    }
   ],
   "source": [
    "# Normalisation shared by the training and validation pipelines.\n",
    "normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n",
    "\n",
    "# Training pipeline: random augmentations, then resize, tensor conversion, normalisation.\n",
    "train_tf = transforms.Compose([\n",
    "    transforms.RandomAffine(10),\n",
    "    transforms.ColorJitter(hue=.05, saturation=.05),\n",
    "    transforms.RandomHorizontalFlip(),\n",
    "    transforms.RandomVerticalFlip(),\n",
    "    transforms.Resize((196, 196)),\n",
    "    transforms.ToTensor(),\n",
    "    normalize,\n",
    "])\n",
    "\n",
    "# Validation pipeline: no augmentation.\n",
    "val_tf = transforms.Compose([\n",
    "    transforms.Resize((196, 196)),\n",
    "    transforms.ToTensor(),\n",
    "    normalize,\n",
    "])\n",
    "\n",
    "kfold = KFold(n_splits=10, shuffle=True, random_state=233)\n",
    "for fold_idx, (train_idx, val_idx) in enumerate(kfold.split(train_jpg, train_jpg)):\n",
    "\n",
    "    train_loader = torch.utils.data.DataLoader(\n",
    "        QRDataset(train_jpg[train_idx], train_tf),\n",
    "        batch_size=20, shuffle=True, num_workers=5, pin_memory=True,\n",
    "    )\n",
    "    val_loader = torch.utils.data.DataLoader(\n",
    "        QRDataset(train_jpg[val_idx], val_tf),\n",
    "        batch_size=10, shuffle=False, num_workers=5, pin_memory=True,\n",
    "    )\n",
    "\n",
    "    model = XunFeiNet().cuda()\n",
    "    criterion = nn.CrossEntropyLoss().cuda()\n",
    "    optimizer = torch.optim.SGD(model.parameters(), 0.01)\n",
    "    # Multiply the learning rate by 0.75 every 5 epochs.\n",
    "    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.75)\n",
    "\n",
    "    best_acc = 0.0\n",
    "    for epoch in range(40):\n",
    "        print('\\nEpoch: ', epoch)\n",
    "\n",
    "        train(train_loader, model, criterion, optimizer, epoch)\n",
    "        val_acc = validate(val_loader, model, criterion)\n",
    "        scheduler.step()\n",
    "\n",
    "        # Keep only the checkpoint with the best validation accuracy so far.\n",
    "        if val_acc > best_acc:\n",
    "            best_acc = val_acc\n",
    "            torch.save(model.state_dict(), './resnet18_fold{0}.pt'.format(fold_idx))\n",
    "\n",
    "    # Only the first fold is trained.\n",
    "    break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-07-17T17:21:39.908717Z",
     "iopub.status.busy": "2021-07-17T17:21:39.908155Z",
     "iopub.status.idle": "2021-07-17T17:23:10.602840Z",
     "shell.execute_reply": "2021-07-17T17:23:10.602183Z",
     "shell.execute_reply.started": "2021-07-17T17:21:39.908669Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Try the directory this cell was written for first; if it has no files, fall\n",
    "# back to the dataset root used by the training cells.\n",
    "test_jpg = []\n",
    "for pattern in ('./人脸情绪识别_数据集/test/*', './Datawhale_人脸情绪识别_数据集/test/*'):\n",
    "    test_jpg = glob.glob(pattern)\n",
    "    if test_jpg:\n",
    "        break\n",
    "if not test_jpg:\n",
    "    # Fail here with a clear message instead of inside np.vstack during predict().\n",
    "    raise FileNotFoundError('No test images found under ./人脸情绪识别_数据集/test/ or ./Datawhale_人脸情绪识别_数据集/test/')\n",
    "test_jpg = np.array(test_jpg)\n",
    "test_jpg.sort()\n",
    "\n",
    "# The flips are random, so each pass made by predict() re-draws them.\n",
    "test_loader = torch.utils.data.DataLoader(\n",
    "        QRDataset(test_jpg,\n",
    "                transforms.Compose([\n",
    "                            transforms.RandomHorizontalFlip(),\n",
    "                            transforms.RandomVerticalFlip(),\n",
    "                            transforms.Resize((196, 196)),\n",
    "                            transforms.ToTensor(),\n",
    "                            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n",
    "            ])\n",
    "        ), batch_size=10, shuffle=False, num_workers=5, pin_memory=True\n",
    ")\n",
    "\n",
    "# Restore the fold-0 checkpoint saved by the training cell and run 5 passes.\n",
    "model = XunFeiNet().cuda()\n",
    "model.load_state_dict(torch.load('resnet18_fold0.pt'))\n",
    "test_pred = predict(test_loader, model, 5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-07-17T17:23:14.441646Z",
     "iopub.status.busy": "2021-07-17T17:23:14.441203Z",
     "iopub.status.idle": "2021-07-17T17:23:14.462041Z",
     "shell.execute_reply": "2021-07-17T17:23:14.461500Z",
     "shell.execute_reply.started": "2021-07-17T17:23:14.441606Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Class names indexed by predicted class id.\n",
    "cls_name = np.array(['angry', 'disgusted', 'fearful', 'happy', 'neutral', 'sad', 'surprised'])\n",
    "pred_idx = test_pred.argmax(1)\n",
    "submit_df = pd.DataFrame({'name': test_jpg, 'label': cls_name[pred_idx]})\n",
    "# Keep only the file name, i.e. the text after the last '/'.\n",
    "submit_df['name'] = submit_df['name'].str.split('/').str[-1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-07-17T17:23:15.144077Z",
     "iopub.status.busy": "2021-07-17T17:23:15.143694Z",
     "iopub.status.idle": "2021-07-17T17:23:15.197091Z",
     "shell.execute_reply": "2021-07-17T17:23:15.196170Z",
     "shell.execute_reply.started": "2021-07-17T17:23:15.144044Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Order the rows by file name and write the submission without the index column.\n",
    "submit_df = submit_df.sort_values('name')\n",
    "submit_df.to_csv('pytorch_submit2.csv', index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
